import torch

# Demonstrate how a row-wise reduction changes a tensor's shape, and how
# unsqueeze(-1) re-inserts the reduced axis as a trailing dimension of size 1.
m = torch.rand(3, 5)  # 3x5 tensor of uniform samples drawn from [0, 1)

# Summing over the last axis collapses each of the 3 rows to one value -> shape (3,).
rowsum = m.sum(dim=-1)

# Print shape then contents for: the original matrix, the 1-D row sums,
# and the row sums reshaped to a (3, 1) column.
for _tensor in (m, rowsum, rowsum.unsqueeze(-1)):
    print(_tensor.shape)
    print(_tensor)
del _tensor  # keep the module namespace identical to the straight-line version


